/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the BSD 3-Clause License (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/BSD-3-Clause
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is
 * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied. See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.master.task;

import com.tencent.angel.protobuf.generated.MLProtos.MatrixClock;
import com.tencent.angel.protobuf.generated.MLProtos.Pair;
import com.tencent.angel.protobuf.generated.WorkerMasterServiceProtos.TaskStateProto;
import com.tencent.angel.utils.CommonUtil;
import com.tencent.angel.worker.task.TaskId;
import it.unimi.dsi.fastutil.ints.Int2IntMap.Entry;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Manager for a single task.
 */

// TODO: 17/6/26 by zmyer
public class AMTask {
    private static final Log LOG = LogFactory.getLog(AMTask.class);
    /** task id */
    private final TaskId taskId;

    /** task state */
    private AMTaskState state;

    /** task iteration */
    private int iteration;

    /** matrix id to clock value map */
    private final Int2IntOpenHashMap matrixIdToClockMap;

    /** task run progress in current iteration */
    private float progress;

    /** task metrics */
    private final Map<String, String> metrics;

    /** task startup time */
    private long startTime;

    /** task finish time */
    private long finishTime;

    private final Lock readLock;
    private final Lock writeLock;

    public AMTask(TaskId id, AMTask amTask) {
        state = AMTaskState.NEW;
        taskId = id;
        metrics = new HashMap<>();
        startTime = -1;
        finishTime = -1;

        matrixIdToClockMap = new Int2IntOpenHashMap();
        // if amTask is not null, we should clone task state from it
        if (amTask == null) {
            iteration = 0;
            progress = 0.0f;
        } else {
            iteration = amTask.getIteration();
            progress = amTask.getProgress();
            matrixIdToClockMap.putAll(amTask.matrixIdToClockMap);
        }

        ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
        readLock = readWriteLock.readLock();
        writeLock = readWriteLock.writeLock();
    }

    /**
     * get task id
     *
     * @return TaskId task id
     */
    public TaskId getTaskId() {
        return taskId;
    }

    /**
     * get task state
     *
     * @return AMTaskState task state
     */
    public AMTaskState getState() {
        try {
            readLock.lock();
            return state;
        } finally {
            readLock.unlock();
        }
    }

    private void setState(AMTaskState newState) {
        try {
            writeLock.lock();
            switch (state) {
                case NEW: {
                    state = newState;
                    break;
                }
                case INITED: {
                    if (newState != AMTaskState.NEW) {
                        state = newState;
                    }
                    break;
                }
                case RUNNING:
                case WAITING: {
                    if (newState != AMTaskState.NEW && newState != AMTaskState.INITED) {
                        state = newState;
                        break;
                    }
                }
                case COMMITING: {
                    if (newState == AMTaskState.SUCCESS || newState == AMTaskState.FAILED
                        || newState == AMTaskState.KILLED) {
                        state = newState;
                    }
                    break;
                }

                default: {
                    break;
                }
            }
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * get the task progress in current iteration
     *
     * @return float the task progress in current iteration
     */
    public float getProgress() {
        try {
            readLock.lock();
            return progress;
        } finally {
            readLock.unlock();
        }
    }

    /**
     * update task information: iteration, clocks, counters, progress, startup time and finish time
     * etc.
     *
     * @param taskStateProto task information from worker
     */
    public void updateTaskState(TaskStateProto taskStateProto) {
        try {
            writeLock.lock();
            LOG.debug("taskStateProto=" + taskStateProto);
            setState(transformState(taskStateProto.getState()));
            this.progress = taskStateProto.getProgress();
            if (taskStateProto.hasIteration()) {
                this.iteration = taskStateProto.getIteration();
            }
            updateMatrixClocks(taskStateProto.getMatrixClocksList());
            updateCounters(taskStateProto.getCountersList());

            if (taskStateProto.hasFinishTime()) {
                this.finishTime = taskStateProto.getFinishTime();
            }
            if (taskStateProto.hasStartTime()) {
                this.startTime = taskStateProto.getStartTime();
            }
        } finally {
            writeLock.unlock();
        }
    }

    private void updateCounters(List<Pair> counters) {
        for (Pair counter : counters) {
            metrics.put(counter.getKey(), counter.getValue());
        }
    }

    private void updateMatrixClocks(List<MatrixClock> matrixClocks) {
        for (MatrixClock matrixClock : matrixClocks) {
            matrixIdToClockMap.put(matrixClock.getMatrixId(), matrixClock.getClock());
        }
    }

    private AMTaskState transformState(String state) {
        switch (state) {
            case "NEW":
                return AMTaskState.NEW;
            case "INITED":
                return AMTaskState.INITED;
            case "RUNNING":
                return AMTaskState.RUNNING;
            case "WAITING":
                return AMTaskState.WAITING;
            case "COMMITING":
                return AMTaskState.COMMITING;
            case "SUCCESS":
                return AMTaskState.SUCCESS;
            case "FAILED":
                return AMTaskState.FAILED;
            case "KILLED":
                return AMTaskState.KILLED;
            default:
                return AMTaskState.NEW;
        }
    }

    /**
     * get task startup time
     *
     * @return long task startup time
     */
    public long getStartTime() {
        try {
            readLock.lock();
            return startTime;
        } finally {
            readLock.unlock();
        }
    }

    /**
     * get task finish time
     *
     * @return long task finish time
     */
    public long getFinishTime() {
        try {
            readLock.lock();
            return finishTime;
        } finally {
            readLock.unlock();
        }
    }

    /**
     * get the current iteration value
     *
     * @return int the current iteration value
     */
    public int getIteration() {
        try {
            readLock.lock();
            return iteration;
        } finally {
            readLock.unlock();
        }
    }

    @Override
    public String toString() {
        try {
            readLock.lock();
            return "AMTask [taskId=" + taskId + ", state=" + state + ", iteration=" + iteration
                + ", matrixIdToClockMap=" + matrixIdToClockMap.toString()
                + ", progress=" + progress
                + "]";
        } finally {
            readLock.unlock();
        }
    }

    /**
     * get task metrics
     *
     * @return Map<String String> task metrics
     */
    public Map<String, String> getMetrics() {
        return CommonUtil.getMetrics(metrics, readLock);
    }

    /**
     * update the clock value of the specified matrix
     *
     * @param matrixId matrix id
     * @param clock clock value
     */
    public void clock(int matrixId, int clock) {
        try {
            writeLock.lock();
            matrixIdToClockMap.put(matrixId, clock);
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * update the task current iteration value
     *
     * @param iteration the task current iteration value
     */
    public void iteration(int iteration) {
        try {
            writeLock.lock();
            this.iteration = iteration;
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * get all matrix clocks
     *
     * @return Int2IntOpenHashMap  all matrix clocks
     */
    public Int2IntOpenHashMap getMatrixClocks() {
        try {
            readLock.lock();
            return matrixIdToClockMap.clone();
        } finally {
            readLock.unlock();
        }
    }

    /**
     * get the clock of the specified matrix
     *
     * @param matrixId the matrix id
     * @return int the clock of the matrix
     */
    public int getMatrixClock(int matrixId) {
        try {
            readLock.lock();
            return matrixIdToClockMap.get(matrixId);
        } finally {
            readLock.unlock();
        }
    }

    /**
     * write the task state to a output stream
     *
     * @param output the output stream
     */
    public void serialize(DataOutputStream output) throws IOException {
        try {
            readLock.lock();
            output.writeUTF(state.toString());
            output.writeInt(iteration);
            output.writeInt(matrixIdToClockMap.size());
            for (Entry clockEntry : matrixIdToClockMap.int2IntEntrySet()) {
                output.writeInt(clockEntry.getIntKey());
                output.writeInt(clockEntry.getIntValue());
            }
        } finally {
            readLock.unlock();
        }
    }

    /**
     * read the task state from a input stream
     *
     * @param input the input stream
     */
    public void deserialize(DataInputStream input) throws IOException {
        try {
            writeLock.lock();
            state = transformState(input.readUTF());
            iteration = input.readInt();
            int clockNum = input.readInt();
            for (int i = 0; i < clockNum; i++) {
                matrixIdToClockMap.put(input.readInt(), input.readInt());
            }
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Reset some profiling counters
     */
    public void resetCounters() {
        metrics.put(TaskCounter.TOTAL_CALCULATE_SAMPLES, "0");
        metrics.put(TaskCounter.TOTAL_CALCULATE_TIME_MS, "0");
    }
}
